##
## Licensed to the .NET Foundation under one or more agreements.
## The .NET Foundation licenses this file to you under the MIT license.
##

import os
from datetime import datetime

from utils import next_power_of_2
from bitonic_isa import BitonicISA


# Size in bytes of each C++ element type this generator supports.
# ScalarBitonicISA.supported_types() returns these keys, so the order written
# here is the order callers iterate them in.
scalar_native_size_map = dict(
    int64_t=8,
    uint64_t=8,
    double=8,
    int32_t=4,
    uint32_t=4,
    float=4,
)

class ScalarBitonicISA(BitonicISA):
    """Generate the scalar (non-SIMD) bitonic sorting network as C++ source.

    Each "vector" in the emitted code is a two-lane struct ``{T}_v`` with
    fields ``a`` and ``b`` (see ``generate_prologue``). The recursive bitonic
    network is therefore expressed with plain compare-and-``std::swap``
    statements. Every ``generate_*`` method that takes a file object writes
    its output there with ``print``.
    """

    def __init__(self, type):
        """Create a generator for a single C++ element type.

        Args:
            type: C++ element type name (e.g. ``"int64_t"``, ``"float"``).
                It is expected to be one of ``supported_types()``. The name
                shadows the builtin but is kept so keyword callers keep working.
        """
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by BitonicISA or the generator driver -- confirm before changing.
        self.vector_size_in_bytes = 16

        self.type = type

        # Element type -> name of its generated two-lane struct.
        self.bitonic_type_map = {
            "int64_t": "int64_t_v",
            "uint64_t": "uint64_t_v",
            "double": "double_v",
            "int32_t": "int32_t_v",
            "uint32_t": "uint32_t_v",
            "float": "float_v",
        }

        # Element type -> short signedness/width suffix. Not read in this
        # class; NOTE(review): presumably used by the driver -- confirm.
        self.bitonic_func_suffix_type_map = {
            "int64_t": "s64",
            "uint64_t": "u64",
            "double": "f64",
            "int32_t": "s32",
            "uint32_t": "u32",
            "float": "f32",
        }

    def max_bitonic_sort_vectors(self):
        """Return the widest network, in vectors, that is generated (16)."""
        return 16

    def vector_size(self):
        """Return the number of lanes (elements) in one scalar "vector" (2)."""
        return 2

    def vector_type(self):
        """Return the generated struct name for ``self.type`` (e.g. ``int64_t_v``).

        Raises:
            KeyError: if ``self.type`` is not a supported element type.
        """
        return self.bitonic_type_map[self.type]

    @classmethod
    def supported_types(cls):
        """Return the element type names this generator can emit code for."""
        return scalar_native_size_map.keys()

    def generate_param_list(self, start, numParams):
        """Return ``"dSS, dSS+1, ..."`` naming *numParams* vectors from index *start*."""
        return ", ".join(f"d{p:02d}" for p in range(start, start + numParams))

    def generate_param_def_list(self, numParams):
        """Return a C++ parameter list of *numParams* vector references ``d01..dNN``."""
        vt = self.vector_type()
        return ", ".join(f"{vt}& d{p:02d}" for p in range(1, numParams + 1))

    def generate_sort_vec(self, cmptype, left, right):
        """Return C++ statements that compare-and-swap two vectors lane by lane.

        Each lane is swapped when ``left.lane {cmptype} right.lane`` holds.
        With ``cmptype == ">"``, this leaves the lane-wise minimum in *left*
        and the maximum in *right*. The text starts with a newline and has no
        trailing newline.
        """
        return f"""
        if ({left}.a {cmptype} {right}.a) std::swap({left}.a, {right}.a);
        if ({left}.b {cmptype} {right}.b) std::swap({left}.b, {right}.b);"""

    def autogenerated_blabber(self):
        """Return the license and "auto-generated, do not edit" banner with a timestamp."""
        # %F is a C99 strftime extension that not every platform's C library
        # supports; spell the date out so the generator runs everywhere.
        return f"""// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

/////////////////////////////////////////////////////////////////////////////
////
// This file was auto-generated by a tool at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
//
// It is recommended you DO NOT directly edit this file but instead edit
// the code-generator that generated this source file instead.
/////////////////////////////////////////////////////////////////////////////"""

    def generate_prologue(self, f):
        """Write the start of the generated header to *f*.

        Emits, in order:
        - the banner and include guard;
        - ``#include "bitonic_sort.h"``;
        - the ``vxsort::smallsort`` namespace openers;
        - the two-lane ``{T}_v`` struct;
        - the opening of the ``bitonic<T, scalar>`` specialization, with its
          ``N`` (lanes per vector) and ``MAX`` (padding value) constants.

        The specialization is left open for the sorter generators and is
        closed by ``generate_epilogue``.
        """
        t = self.type
        # MAX pads the unused lane of a partially filled last vector, so it
        # must compare >= every real element and stay in that unstored lane.
        # For float/double, numeric_limits::max() is the largest *finite*
        # value: a +inf input would be sorted past the pad and the pad written
        # back in its place. Use +infinity whenever the type has one; integer
        # types are unaffected.
        s = f"""{self.autogenerated_blabber()}

#ifndef BITONIC_SORT_SCALAR_{t.upper()}_H
#define BITONIC_SORT_SCALAR_{t.upper()}_H

#include "bitonic_sort.h"

namespace vxsort {{
namespace smallsort {{

struct {t}_v
{{
 {t} a;
 {t} b;
}};

template<> struct bitonic<{t}, scalar> {{
    static const int N = {self.vector_size()};
    static constexpr {t} MAX = std::numeric_limits<{t}>::has_infinity ? std::numeric_limits<{t}>::infinity() : std::numeric_limits<{t}>::max();
public:
"""
        print(s, file=f)

    def generate_epilogue(self, f):
        """Close the ``bitonic`` specialization, both namespaces and the include guard."""
        # The comment names the same macro generate_prologue defines.
        guard = f"BITONIC_SORT_SCALAR_{self.type.upper()}_H"
        s = f"""
}};
}}
}}

#endif // {guard}"""
        print(s, file=f)

    def generate_1v_basic_sorters(self, f, ascending):
        """Emit ``sort_01v_{ascending|descending}``, which orders the two lanes of ``d01``."""
        suffix = "ascending" if ascending else "descending"
        # Swap when the lanes are out of order for the requested direction.
        cmptype = ">" if ascending else "<"

        s = f"""    static INLINE void sort_01v_{suffix}({self.generate_param_def_list(1)}) {{
        if (d01.a {cmptype} d01.b) std::swap(d01.a, d01.b);
    }}\n"""
        print(s, file=f)

    def generate_1v_merge_sorters(self, f, ascending: bool):
        """Emit ``sort_01v_merge_{ascending|descending}`` for a single vector.

        With only two lanes, merging one vector is the same single
        compare-and-swap as sorting it.
        """
        suffix = "ascending" if ascending else "descending"
        cmptype = ">" if ascending else "<"

        s = f"""    static INLINE void sort_01v_merge_{suffix}({self.generate_param_def_list(1)}) {{
        if (d01.a {cmptype} d01.b) std::swap(d01.a, d01.b);
    }}\n"""
        print(s, file=f)

    def generate_compounded_sorter(self, f, width, ascending, inline):
        """Emit ``sort_{width}v_{direction}`` built from two smaller networks.

        The emitted function works in three steps:
        1. Sort the first ``w1`` vectors (half the next power of two of
           *width*) in the requested direction, and the remaining ``w2``
           vectors in the opposite direction.
        2. Apply a cross compare-and-swap step between the two halves.
        3. Finish each half with the ``*_merge_*`` network of the requested
           direction.

        Args:
            f: writable text file receiving the generated C++.
            width: number of vectors sorted by the emitted function.
            ascending: selects the ``ascending`` or ``descending`` variant.
            inline: emit ``INLINE`` when true, otherwise ``NOINLINE``.
        """
        w1 = int(next_power_of_2(width) / 2)
        w2 = int(width - w1)

        suffix = "ascending" if ascending else "descending"
        rev_suffix = "descending" if ascending else "ascending"
        inl = "INLINE" if inline else "NOINLINE"

        s = f"""    static {inl} void sort_{width:02d}v_{suffix}({self.generate_param_def_list(width)}) {{
        sort_{w1:02d}v_{suffix}({self.generate_param_list(1, w1)});
        sort_{w2:02d}v_{rev_suffix}({self.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)

        # Pair the second half with the first half in reverse vector order:
        # d(w1+1) with d(w1), d(w1+2) with d(w1-1), and so on.
        # The comparison is ">" for BOTH directions, so the lane-wise minimum
        # always lands in the lower-numbered vector. The direction is applied
        # only through the sub-sorter/merger suffixes.
        for r in range(w1 + 1, width + 1):
            x = 2 * w1 + 1 - r
            print(self.generate_sort_vec(">", f"d{x:02d}", f"d{r:02d}"), file=f)

        s = f"""
        sort_{w1:02d}v_merge_{suffix}({self.generate_param_list(1, w1)});
        sort_{w2:02d}v_merge_{suffix}({self.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)
        print("    }\n", file=f)

    def generate_compounded_merger(self, f, width, ascending, inline):
        """Emit ``sort_{width}v_merge_{direction}``.

        The emitted function first compare-and-swaps ``d(r)`` against
        ``d(r - w1)`` for every vector in the second half. It then recursively
        merges the first ``w1`` and the remaining ``w2`` vectors in the
        requested direction.

        Args:
            f: writable text file receiving the generated C++.
            width: number of vectors merged by the emitted function.
            ascending: selects the ``ascending`` or ``descending`` variant.
            inline: emit ``INLINE`` when true, otherwise ``NOINLINE``.
        """
        w1 = int(next_power_of_2(width) / 2)
        w2 = int(width - w1)

        suffix = "ascending" if ascending else "descending"
        inl = "INLINE" if inline else "NOINLINE"

        s = f"    static {inl} void sort_{width:02d}v_merge_{suffix}({self.generate_param_def_list(width)}) {{"
        print(s, file=f)

        # Straight pairing: d(w1+k) against d(k). As in the sorter, ">" is
        # used for both directions so the minimum stays in the lower vector.
        for r in range(w1 + 1, width + 1):
            x = r - w1
            print(self.generate_sort_vec(">", f"d{x:02d}", f"d{r:02d}"), file=f)

        s = f"""
        sort_{w1:02d}v_merge_{suffix}({self.generate_param_list(1, w1)});
        sort_{w2:02d}v_merge_{suffix}({self.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)
        print("    }\n", file=f)

    def generate_entry_points(self, f):
        """Emit ``sort_{m}v_alt(ptr, remainder)`` for every width ``m`` from 1 to max.

        Each emitted function does the following:
        - Loads ``m`` vectors from ``ptr``. All but the last are always full.
        - Loads the last vector's ``b`` lane only when ``remainder == 0``;
          otherwise that lane is padded with ``MAX``.
        - Calls ``sort_{m}v_ascending``.
        - Stores the elements back, skipping the padded lane.
        """
        t = self.type
        # Looked up up front so an unsupported type fails before any output.
        vt = self.vector_type()

        for m in range(1, self.max_bitonic_sort_vectors() + 1):
            last = 2 * (m - 1)  # offset of the first element of vector d{m}

            s = f"""
    static NOINLINE void sort_{m:02d}v_alt({t} *ptr, int remainder) {{"""
            print(s, file=f)
            for l in range(m):
                print(f"        {vt} d{l + 1:02d};", file=f)
            print("", file=f)

            for l in range(m - 1):
                s = f"""        d{l + 1:02d}.a = *(ptr + {l*2});
        d{l + 1:02d}.b = *(ptr + {l*2 + 1});"""
                print(s, file=f)

            # The last vector may hold only one real element.
            print(f"        d{m:02d}.a = *(ptr + {last});", file=f)
            print(f"        d{m:02d}.b = (remainder == 0) ? *(ptr + {last + 1}) : MAX;", file=f)

            print(f"\n        sort_{m:02d}v_ascending({self.generate_param_list(1, m)});\n", file=f)

            for l in range(m - 1):
                s = f"""        *(ptr + {l*2}) = d{l + 1:02d}.a;
        *(ptr + {l*2 + 1}) = d{l + 1:02d}.b;"""
                print(s, file=f)

            print(f"        *(ptr + {last}) = d{m:02d}.a;", file=f)
            print(f"        if (remainder == 0) *(ptr + {last + 1}) = d{m:02d}.b;", file=f)

            print("    }", file=f)

    def generate_master_entry_point(self, f_header, f_src):
        """Emit the public ``sort(ptr, length)`` entry point.

        Writes to *f_src*:
        - a license header;
        - ``#include "common.h"`` and an include of *f_header* by its basename;
        - a definition that computes how many vectors *length* occupies
          (counting a partially filled last one) and dispatches to the
          matching ``sort_NNv_alt``.

        Writes only the declaration to *f_header*.

        NOTE(review): the generated ``switch`` has no ``default``, so a length
        that needs more than ``max_bitonic_sort_vectors()`` vectors is silently
        left unsorted. Presumably callers never pass one; confirm.
        """
        t = self.type
        basename = os.path.basename(f_header.name)

        s = f"""// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#include "common.h"
#include "{basename}"

using namespace vxsort;
"""
        print(s, file=f_src)

        print(f"""    static void sort({t} *ptr, size_t length);""", file=f_header)

        s = f"""void vxsort::smallsort::bitonic<{t}, vector_machine::scalar >::sort({t} *ptr, size_t length) {{
    const auto fullvlength = length / N;
    const int remainder = (int) (length - fullvlength * N);
    const auto v = fullvlength + ((remainder > 0) ? 1 : 0);
    switch(v) {{"""
        print(s, file=f_src)

        for m in range(1, self.max_bitonic_sort_vectors() + 1):
            print(f"        case {m}: sort_{m:02d}v_alt(ptr, remainder); break;", file=f_src)
        print("    }", file=f_src)
        print("}", file=f_src)
